# Copyright © 2023-2024 Apple Inc.

import json
import os
import unittest
from dataclasses import asdict
from pathlib import Path

import mlx.core as mx
import mlx_whisper
import mlx_whisper.audio as audio
import mlx_whisper.decoding as decoding
import mlx_whisper.load_models as load_models
import numpy as np
import torch
from convert import convert, load_torch_model, quantize
from mlx.utils import tree_flatten

# Model name handed to load_torch_model() and convert() below.
MODEL_NAME = "tiny"
# Directories that load_torch_and_mlx() fills with the converted MLX models
# (fp32, fp16 and 4-bit quantized); the tests load the models back from here.
MLX_FP32_MODEL_PATH = "mlx_models/tiny_fp32"
MLX_FP16_MODEL_PATH = "mlx_models/tiny_fp16"
MLX_4BITS_MODEL_PATH = "mlx_models/tiny_quantized_4bits"
# Audio clip used by the decoding, transcription and audio-helper tests.
TEST_AUDIO = "mlx_whisper/assets/ls_test.flac"


def _save_model(save_dir, weights, config):
    """Write a model's weights and config into ``save_dir``.

    Creates ``save_dir`` (and any missing parents), stores ``weights`` as
    ``weights.npz`` and ``config`` as ``config.json``. The JSON file always
    carries ``"model_type": "whisper"``.

    Args:
        save_dir: Target directory (string or path-like).
        weights: Mapping of parameter name to array; each entry becomes one
            array in the ``.npz`` archive.
        config: JSON-serializable mapping of model settings. It is not
            modified by this function.
    """
    mlx_path = Path(save_dir)
    mlx_path.mkdir(parents=True, exist_ok=True)

    # Save weights
    np.savez(str(mlx_path / "weights.npz"), **weights)

    # Save config.json with model_type. A merged copy is serialized so the
    # caller's dict is never touched: no key is left behind if json.dump()
    # raises, and a pre-existing "model_type" entry is not deleted.
    with open(str(mlx_path / "config.json"), "w") as f:
        json.dump({**config, "model_type": "whisper"}, f, indent=4)


def load_torch_and_mlx():
    """Load the torch reference model and build the MLX models used by the tests.

    Converts ``MODEL_NAME`` to MLX in float32 and float16 and saves both with
    ``_save_model``. The float16 weights and config are then passed through
    ``quantize`` (group size 64, 4 bits) and the result is saved as well.

    Returns:
        Tuple ``(torch_model, fp32_model, fp16_model)``.
    """
    from types import SimpleNamespace

    torch_model = load_torch_model(MODEL_NAME)

    fp32_model = convert(MODEL_NAME, dtype=mx.float32)
    fp32_config = asdict(fp32_model.dims)
    fp32_weights = dict(tree_flatten(fp32_model.parameters()))
    _save_model(MLX_FP32_MODEL_PATH, fp32_weights, fp32_config)

    fp16_model = convert(MODEL_NAME, dtype=mx.float16)
    fp16_config = asdict(fp16_model.dims)
    fp16_weights = dict(tree_flatten(fp16_model.parameters()))
    _save_model(MLX_FP16_MODEL_PATH, fp16_weights, fp16_config)

    # quantize() is handed its settings as attributes on an args object.
    # The quantized model is derived from the float16 weights and config.
    quantize_args = SimpleNamespace(q_group_size=64, q_bits=4)
    q_weights, q_config = quantize(fp16_weights, fp16_config, quantize_args)
    _save_model(MLX_4BITS_MODEL_PATH, q_weights, q_config)

    return torch_model, fp32_model, fp16_model


def forward_torch(model, mels, tokens):
    """Run the torch model on NumPy inputs and return the logits as NumPy.

    Args:
        model: Object whose ``forward(mels, tokens)`` returns a torch tensor.
        mels: Array-like of features; converted to a float32 tensor.
        tokens: Array-like of integer token ids; converted to an int32 tensor.

    Returns:
        The result of ``model.forward`` converted with ``.numpy()``.
    """
    # Build each tensor directly in its target dtype. torch.Tensor(...) always
    # produces float32 first, which rounds integer ids above 2**24 before they
    # are cast back to int32.
    mels = torch.tensor(mels, dtype=torch.float32)
    tokens = torch.tensor(tokens, dtype=torch.int32)
    with torch.no_grad():
        logits = model.forward(mels, tokens)
    return logits.numpy()


def forward_mlx(model, mels, tokens):
    """Run the MLX model on NumPy inputs and return the logits as NumPy.

    The last two axes of ``mels`` are swapped before the call; the caller in
    this file passes ``mels`` shaped ``(1, n_mels, 3000)``. ``tokens`` is
    converted to an int32 MLX array.
    """
    mels_swapped = mx.array(mels.transpose(0, 2, 1))
    token_ids = mx.array(tokens, mx.int32)
    return np.array(model(mels_swapped, token_ids))


class TestWhisper(unittest.TestCase):
    """End-to-end checks of the MLX Whisper model against fixed references.

    ``setUpClass`` converts the models once (see ``load_torch_and_mlx``) and
    computes the log-mel spectrogram of ``TEST_AUDIO`` for the decoding tests.
    """

    # Transcript expected for TEST_AUDIO. transcribe() returns it with a
    # leading space, decode() without.
    _EXPECTED_TEXT = (
        "Then the good soul openly sorted the boat and she "
        "had buoyed so long in secret and bravely stretched on alone."
    )

    @classmethod
    def setUpClass(cls):
        # Keep the torch model too so test_torch_mlx does not load it again.
        cls.torch_model, cls.model, _ = load_torch_and_mlx()
        data = audio.load_audio(TEST_AUDIO)
        data = audio.pad_or_trim(data)
        cls.mels = audio.log_mel_spectrogram(data)

    def _assert_fp16_forward(self, model_path):
        """Load the model at ``model_path`` as float16 and check the logits dtype."""
        mlx_model = load_models.load_model(model_path, mx.float16)
        dims = mlx_model.dims
        mels = mx.array(np.random.randn(1, 3_000, dims.n_mels), mx.float16)
        tokens = mx.array(np.random.randint(0, dims.n_vocab, (1, 20)), mx.int32)
        logits = mlx_model(mels, tokens)
        self.assertEqual(logits.dtype, mx.float16)

    def _check_greedy_result(self, result):
        """Assert the text and summary statistics expected from greedy decoding."""
        self.assertEqual(result.text, self._EXPECTED_TEXT)
        self.assertAlmostEqual(result.avg_logprob, -0.4975455382774616, places=3)
        self.assertAlmostEqual(result.no_speech_prob, 0.009631240740418434, places=4)
        self.assertAlmostEqual(result.compression_ratio, 1.2359550561797752)

    def test_torch_mlx(self):
        np.random.seed(10)

        torch_model = self.torch_model
        dims = torch_model.dims

        mels = np.random.randn(1, dims.n_mels, 3_000)
        tokens = np.random.randint(0, dims.n_vocab, (1, 20))

        torch_logits = forward_torch(torch_model, mels, tokens)

        mlx_logits = forward_mlx(self.model, mels, tokens)

        # Same tolerance rule as np.allclose(torch_logits, mlx_logits, ...),
        # but a failure reports how many elements differ and by how much.
        np.testing.assert_allclose(
            torch_logits, mlx_logits, rtol=1e-2, atol=1e-2, equal_nan=False
        )

    def test_fp16(self):
        self._assert_fp16_forward(MLX_FP16_MODEL_PATH)

    def test_quantized_4bits(self):
        # Here, we just test if 4-bit models can forward, as the quantized
        # tiny models struggle with accurate transcription
        self._assert_fp16_forward(MLX_4BITS_MODEL_PATH)

    def test_decode_lang(self):
        options = decoding.DecodingOptions(task="lang_id", fp16=False)
        result = decoding.decode(self.model, self.mels, options)
        self.assertEqual(result.language, "en")
        self.assertEqual(len(result.language_probs), 99)
        self.assertAlmostEqual(
            result.language_probs["en"], 0.9947282671928406, places=5
        )

    def test_decode_greedy(self):
        result = decoding.decode(self.model, self.mels, fp16=False)
        self.assertEqual(result.language, "en")
        self.assertEqual(
            result.tokens,
            [
                50364,
                1396,
                264,
                665,
                5133,
                23109,
                25462,
                264,
                6582,
                293,
                750,
                632,
                42841,
                292,
                370,
                938,
                294,
                4054,
                293,
                12653,
                356,
                50620,
                50620,
                23563,
                322,
                3312,
                13,
                50680,
            ],
        )
        self._check_greedy_result(result)

        # Small temp should give the same results
        result = decoding.decode(self.model, self.mels, temperature=1e-8, fp16=False)
        self._check_greedy_result(result)

    def test_transcribe(self):
        result = mlx_whisper.transcribe(
            TEST_AUDIO, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
        )
        self.assertEqual(result["text"], " " + self._EXPECTED_TEXT)

    def test_transcribe_alice(self):
        audio_file = os.path.join(
            os.path.expanduser("~"),
            ".cache/whisper/alice.mp3",
        )
        if not os.path.exists(audio_file):
            # Report the test as skipped instead of silently passing.
            self.skipTest(
                "To run this test download the alice in wonderland audiobook: "
                "bash path_to_whisper_repo/whisper/assets/download_alice.sh"
            )

        result = mlx_whisper.transcribe(
            audio_file, path_or_hf_repo=MLX_FP32_MODEL_PATH, fp16=False
        )
        self.assertEqual(len(result["text"]), 10920)
        self.assertEqual(result["language"], "en")
        self.assertEqual(len(result["segments"]), 77)

        expected_5 = {
            "id": 5,
            "seek": 2800,
            "start": 40.0,
            "end": 46.0,
            "text": " Oh my poor little feet, I wonder who will put on your shoes and stockings for you now tears.",
            "tokens": [
                50964,
                876,
                452,
                4716,
                707,
                3521,
                11,
                286,
                2441,
                567,
                486,
                829,
                322,
                428,
                6654,
                293,
                4127,
                1109,
                337,
                291,
                586,
                10462,
                13,
                51264,
            ],
            "temperature": 0.0,
            "avg_logprob": -0.19670599699020386,
            "compression_ratio": 1.5991379310344827,
            "no_speech_prob": 0.09746722131967545,
        }

        expected_73 = {
            "id": 73,
            "seek": 70700,
            "start": 707.0,
            "end": 715.0,
            "text": " let us get to the shore, and then I'll tell you my history, and you'll understand why it is that I hate cats and dogs.",
            "tokens": [
                50364,
                718,
                505,
                483,
                281,
                264,
                17805,
                11,
                293,
                550,
                286,
                603,
                980,
                291,
                452,
                2503,
                11,
                293,
                291,
                603,
                1223,
                983,
                309,
                307,
                300,
                286,
                4700,
                11111,
                293,
                7197,
                13,
                50764,
            ],
            "temperature": 0.0,
            "avg_logprob": -0.1350895343440594,
            "compression_ratio": 1.6208333333333333,
            "no_speech_prob": 0.009053784422576427,
        }

        def check_segment(seg, expected):
            for k, v in expected.items():
                if isinstance(v, float):
                    self.assertAlmostEqual(seg[k], v, places=2)
                else:
                    self.assertEqual(seg[k], v)

        # Randomly check a couple of segments
        check_segment(result["segments"][5], expected_5)
        check_segment(result["segments"][73], expected_73)

    def test_transcribe_word_level_timestamps_confidence_scores(self):
        result = mlx_whisper.transcribe(
            TEST_AUDIO,
            path_or_hf_repo=MLX_FP16_MODEL_PATH,
            word_timestamps=True,
        )

        # result predicted with openai-whisper
        expected_0 = [
            {
                "word": " Then",
                "start": 0.0,
                "end": 0.94,
                "probability": 0.855542778968811,
            },
            {
                "word": " the",
                "start": 0.94,
                "end": 1.12,
                "probability": 0.6500106453895569,
            },
            {
                "word": " good",
                "start": 1.12,
                "end": 1.32,
                "probability": 0.5503873825073242,
            },
            {
                "word": " soul",
                "start": 1.32,
                "end": 1.56,
                "probability": 0.46757155656814575,
            },
            {
                "word": " openly",
                "start": 1.56,
                "end": 2.0,
                "probability": 0.9840946793556213,
            },
            {
                "word": " sorted",
                "start": 2.0,
                "end": 2.38,
                "probability": 0.24167272448539734,
            },
            {
                "word": " the",
                "start": 2.38,
                "end": 2.58,
                "probability": 0.9875414967536926,
            },
            {
                "word": " boat",
                "start": 2.58,
                "end": 2.8,
                "probability": 0.5856029391288757,
            },
            {
                "word": " and",
                "start": 2.8,
                "end": 2.98,
                "probability": 0.913351833820343,
            },
            {
                "word": " she",
                "start": 2.98,
                "end": 3.1,
                "probability": 0.9913808703422546,
            },
            {
                "word": " had",
                "start": 3.1,
                "end": 3.32,
                "probability": 0.9952940344810486,
            },
            {
                "word": " buoyed",
                "start": 3.32,
                "end": 3.58,
                "probability": 0.6411589980125427,
            },
            {
                "word": " so",
                "start": 3.58,
                "end": 3.8,
                "probability": 0.9682658314704895,
            },
            {
                "word": " long",
                "start": 3.8,
                "end": 4.06,
                "probability": 0.9953522682189941,
            },
            {
                "word": " in",
                "start": 4.06,
                "end": 4.26,
                "probability": 0.6745936870574951,
            },
            {
                "word": " secret",
                "start": 4.26,
                "end": 4.56,
                "probability": 0.9905064702033997,
            },
            {
                "word": " and",
                "start": 4.56,
                "end": 4.9,
                "probability": 0.856008768081665,
            },
            {
                "word": " bravely",
                "start": 4.9,
                "end": 5.28,
                "probability": 0.8477402329444885,
            },
        ]

        def check_words(words, expected_words):
            # zip() stops at the shorter input, so make sure every expected
            # word has a counterpart; otherwise the loop passes vacuously.
            self.assertGreaterEqual(len(words), len(expected_words))
            for word, expected_word in zip(words, expected_words):
                for k, v in expected_word.items():
                    if isinstance(v, float):
                        self.assertAlmostEqual(word[k], v, places=1)
                    else:
                        self.assertEqual(word[k], v)

        # Check the words of the first segment
        check_words(result["segments"][0]["words"], expected_0)


class TestAudio(unittest.TestCase):
    """Sanity checks for the helpers in ``mlx_whisper.audio``.

    The checks use real comparisons (``assertEqual`` / ``assertIn``).
    ``assertTrue(value, expected)`` treats its second argument as the failure
    *message*, so it passes for any truthy ``value`` and verifies nothing.
    Shapes are normalised with ``tuple()`` so the comparison does not depend
    on whether ``.shape`` is a list or a tuple.
    """

    def test_load(self):
        data = audio.load_audio(TEST_AUDIO)
        # Second positional argument is presumably the target sample rate
        # (the expected length below is halved) — TODO confirm.
        data_8k = audio.load_audio(TEST_AUDIO, 8000)
        n = 106640
        self.assertEqual(tuple(data.shape), (n,))
        # NOTE(review): the original named np.float32 but never compared it.
        # Accept float32 from either NumPy or MLX until it is confirmed which
        # array type load_audio returns, then tighten this.
        self.assertIn(data.dtype, (np.float32, mx.float32))
        self.assertEqual(tuple(data_8k.shape), (n // 2,))

    def test_pad(self):
        data = audio.load_audio(TEST_AUDIO)
        data = audio.pad_or_trim(data, 20_000)
        self.assertEqual(tuple(data.shape), (20_000,))

    def test_mel_spec(self):
        mels = audio.log_mel_spectrogram(TEST_AUDIO)
        # NOTE(review): the previous expectation ([80, 400]) was never
        # actually compared. The tests in this file feed the model mels with
        # the feature axis last, so only that axis (80 bins, as in the
        # original expectation) is checked; the frame count is unverified —
        # TODO pin it once confirmed.
        self.assertEqual(mels.ndim, 2)
        self.assertEqual(mels.shape[-1], 80)
        self.assertEqual(mels.dtype, mx.float32)


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
